import sys
import heapq

# Build per-seed prototype passage lists. For every dataset and GNN seed, pick
# the most confident predicted nodes of each class, keep PPR neighbors that
# share the node's predicted class, and write those neighbors' titles out.

# Root directory holding the per-dataset input and output files.
output_folder = "../processed_data"

# Datasets processed by the loop below.
datasets = [
    "cora",
]

# Presumably the number of label classes per dataset (not read by the loop
# below; kept for other consumers — confirm).
dataset2nclasses = {
    "cora": 7,
    "pubmed": 3,
    "ogbn-products": 47,
    "ogbn-arxiv": 40,
}

def _read_ppr_list(path, max_neighbors=20):
    """Read a PPR adjacency file into a list of ``(target, neighbors)`` pairs.

    Each non-blank line is ``<target>\\t<space-separated neighbor ids>``. Only
    the first ``max_neighbors`` neighbors are kept, in file order. A line with
    no neighbor column yields an empty neighbor list instead of crashing.
    """
    entries = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split('\t')
            if not fields[0]:
                continue  # blank line: nothing to parse
            target = int(fields[0])
            # split() (no sep) tolerates repeated spaces, unlike split(' ').
            raw_neighbors = fields[1].split() if len(fields) > 1 else []
            neighbors = [int(x) for x in raw_neighbors[:max_neighbors]]
            entries.append((target, neighbors))
    return entries


def _read_title_list(path):
    """Read one title per line; list index ``k`` is line ``k`` of the file."""
    with open(path, 'r', encoding='utf-8') as f:
        # Blank lines are kept on purpose so indices stay aligned with node ids.
        return [line.strip() for line in f]


def _read_predictions(path):
    """Read a GNN raw-probability file, one node per line.

    Each line holds tab-separated ``<class>|<confidence>`` fields; only the
    first two fields are used (presumably ordered most-probable first — confirm).

    Returns:
        pred_list: per node, the class names of the first two fields.
        top1: per node, ``(class_name, confidence)`` of the first field.
    """
    pred_list = []
    top1 = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split('\t')[:2]
            pred_list.append([field.split('|')[0] for field in fields])
            class_name, confidence = fields[0].split('|')[:2]
            top1.append((class_name, float(confidence)))
    return pred_list, top1


def _select_prototypes(top1, k=10):
    """Group nodes by top-1 class and keep the ``k`` most confident per class.

    Returns ``{class_name: [(node_id, confidence), ...]}``, with each list in
    descending confidence order (ties keep file order, per ``heapq.nlargest``).
    """
    by_class = {}
    for node_id, (class_name, confidence) in enumerate(top1):
        by_class.setdefault(class_name, []).append((node_id, confidence))
    return {
        class_name: heapq.nlargest(k, members, key=lambda x: x[1])
        for class_name, members in by_class.items()
    }


def _filter_ppr_list(ppr_list, pred_list, prototype_ids, keep=2):
    """Keep, for each prototype target, up to ``keep`` agreeing PPR neighbors.

    A neighbor agrees when the target's top-1 predicted class appears among
    the neighbor's first two predicted classes. Results follow PPR-file order.
    """
    filtered = []
    for target, neighbors in ppr_list:
        if target not in prototype_ids:
            continue
        target_class = pred_list[target][0]  # hoisted: invariant per target
        matches = []
        for neighbor in neighbors:
            if target_class in pred_list[neighbor]:
                matches.append(neighbor)
                if len(matches) >= keep:
                    break  # no need to scan further neighbors
        filtered.append(matches[:keep])
    return filtered


for dataset in datasets:
    dataset_dir = f"{output_folder}/{dataset}"

    ppr_list = _read_ppr_list(f"{dataset_dir}/{dataset}_ppradj_list.txt")
    print(f"Loaded PPR list for {dataset}: {len(ppr_list)} entries.")

    title_list = _read_title_list(f"{dataset_dir}/{dataset}_title_list.txt")
    print(f"Loaded title list for {dataset}: {len(title_list)} entries.")

    # cora/pubmed are read for seeds 0-4; every other dataset only for seed 0.
    seeds = list(range(5)) if dataset in ['cora', 'pubmed'] else [0]
    for seed in seeds:
        pred_file = f"{dataset_dir}/{dataset}_{seed}_gnn_output_raw_probability_list.txt"
        pred_list, top1 = _read_predictions(pred_file)
        print(f"Loaded prediction list for {dataset} with seed {seed}: {len(pred_list)} entries.")

        class2prototype = _select_prototypes(top1, k=10)
        print(f"Class prototypes for {dataset} with seed {seed:02d} loaded.")

        # A set makes the per-target membership test O(1) instead of O(n).
        prototype_ids = {
            node_id
            for prototypes in class2prototype.values()
            for node_id, _confidence in prototypes
        }

        # Filter the PPR list based on predictions.
        filtered_ppr_list = _filter_ppr_list(ppr_list, pred_list, prototype_ids, keep=2)
        print(f"Filtered PPR list for {dataset} with seed {seed:02d} loaded.")

        # NOTE(review): output lines carry no target id; line order matches the
        # prototype targets' order in the PPR file — confirm consumers rely on it.
        prototype_passage_file = f"{dataset_dir}/{dataset}_{seed}_prototype_passage_list.txt"
        with open(prototype_passage_file, 'w', encoding='utf-8') as f:
            for neighbors in filtered_ppr_list:
                f.write(', '.join(f'({title_list[x]})' for x in neighbors) + '\n')
        print(f"Filtered PPR list saved to {prototype_passage_file}.")